{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. 非混合精度训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/workspace/mdy/miniforge/envs/mdy/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/mnt/workspace/mdy/miniforge/envs/mdy/lib/python3.10/site-packages/_distutils_hack/__init__.py:53: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'3.2.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from torch.optim import AdamW\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "import triton\n",
    "import triton.language as tl\n",
    "from TritonAdam import TritonAdamW\n",
    "import os\n",
    "os.environ['TRITON_PRINT_AUTOTUNING'] = '1'\n",
    "torch.cuda.empty_cache()\n",
    "triton.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载模型\n",
    "- 加载fp32的模型进行测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint location; set the MODEL_PATH environment variable to override the default.\n",
    "model_path = os.environ.get('MODEL_PATH', '/mnt/workspace/mdy/models/Qwen2.5-0.5B-Instruct')\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "# No torch_dtype is passed here; this section benchmarks the fp32 model.\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path).cuda()\n",
    "iters = 100  # number of training iterations each optimizer cell runs\n",
    "p = next(model.parameters())  # first parameter; printed later to compare optimizers\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch Adam"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 非Fused版本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:09<00:00, 10.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[-0.1078,  0.1398,  0.0985,  ..., -0.0773, -0.0879, -0.0940],\n",
      "        [ 0.0785, -0.0957,  0.0860,  ..., -0.1042,  0.1048,  0.0868],\n",
      "        [-0.0911,  0.0809, -0.0942,  ...,  0.0930, -0.1123,  0.0945],\n",
      "        ...,\n",
      "        [ 0.1392, -0.1389,  0.1346,  ..., -0.1390, -0.1394,  0.1495],\n",
      "        [ 0.1392, -0.1389,  0.1346,  ..., -0.1390, -0.1394,  0.1495],\n",
      "        [ 0.1392, -0.1389,  0.1346,  ..., -0.1390, -0.1394,  0.1495]],\n",
      "       device='cuda:0', requires_grad=True)\n",
      "25.367843627929688\n"
     ]
    }
   ],
   "source": [
    "# Reload the model so this run starts from the pretrained weights instead of the state\n",
    "# left behind by a previously executed optimizer cell (its training and do_bench steps).\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path).cuda()\n",
    "p = next(model.parameters())  # first parameter, printed below to compare optimizers\n",
    "optimizer = AdamW(model.parameters(), fused=False)\n",
    "inp_ids = torch.arange(128).reshape(4, -1).cuda()\n",
    "for _ in tqdm(range(iters)):\n",
    "    # Clear gradients at the start of the iteration (not the end) so the gradients of the\n",
    "    # final backward pass are still populated for the benchmark below.\n",
    "    model.zero_grad()\n",
    "    out = model(inp_ids)\n",
    "    out.logits.mean().backward()\n",
    "    optimizer.step()\n",
    "print(p)  # should come out roughly the same for every optimizer in this section\n",
    "# Time optimizer.step() alone, reusing the gradients from the last backward pass.\n",
    "ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)\n",
    "print(ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fused版本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:06<00:00, 14.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[-0.1075,  0.1400,  0.0989,  ..., -0.0775, -0.0883, -0.0943],\n",
      "        [ 0.0791, -0.0955,  0.0848,  ..., -0.1041,  0.1041,  0.0858],\n",
      "        [-0.0751,  0.0809, -0.0941,  ...,  0.0931, -0.1123,  0.0944],\n",
      "        ...,\n",
      "        [ 0.1393, -0.1390,  0.1346,  ..., -0.1388, -0.1392,  0.1495],\n",
      "        [ 0.1393, -0.1390,  0.1346,  ..., -0.1388, -0.1392,  0.1495],\n",
      "        [ 0.1393, -0.1390,  0.1346,  ..., -0.1388, -0.1392,  0.1495]],\n",
      "       device='cuda:0', requires_grad=True)\n",
      "8.984747886657715\n"
     ]
    }
   ],
   "source": [
    "# Reload the model so this run starts from the pretrained weights instead of the state\n",
    "# left behind by a previously executed optimizer cell (its training and do_bench steps).\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path).cuda()\n",
    "p = next(model.parameters())  # first parameter, printed below to compare optimizers\n",
    "optimizer = AdamW(model.parameters(), fused=True)\n",
    "inp_ids = torch.arange(128).reshape(4, -1).cuda()\n",
    "for _ in tqdm(range(iters)):\n",
    "    # Clear gradients at the start of the iteration (not the end) so the gradients of the\n",
    "    # final backward pass are still populated for the benchmark below.\n",
    "    model.zero_grad()\n",
    "    out = model(inp_ids)\n",
    "    out.logits.mean().backward()\n",
    "    optimizer.step()\n",
    "print(p)  # should come out roughly the same for every optimizer in this section\n",
    "# Time optimizer.step() alone, reusing the gradients from the last backward pass.\n",
    "ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)\n",
    "print(ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Triton Adam"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 全部fp32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "finish_custom_init, p_dtype: torch.float32, master_p_dtype: None\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:07<00:00, 12.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[-0.1081,  0.1399,  0.0991,  ..., -0.0780, -0.0882, -0.0937],\n",
      "        [ 0.0789, -0.0957,  0.0858,  ..., -0.1044,  0.1044,  0.0871],\n",
      "        [-0.0793,  0.0809, -0.0939,  ...,  0.0929, -0.1125,  0.0944],\n",
      "        ...,\n",
      "        [ 0.1391, -0.1389,  0.1348,  ..., -0.1391, -0.1395,  0.1495],\n",
      "        [ 0.1391, -0.1389,  0.1348,  ..., -0.1391, -0.1395,  0.1495],\n",
      "        [ 0.1391, -0.1389,  0.1348,  ..., -0.1391, -0.1395,  0.1495]],\n",
      "       device='cuda:0', requires_grad=True)\n",
      "8.832446098327637\n"
     ]
    }
   ],
   "source": [
    "# Reload the model so this run starts from the pretrained weights instead of the state\n",
    "# left behind by a previously executed optimizer cell (its training and do_bench steps).\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path).cuda()\n",
    "p = next(model.parameters())  # first parameter, printed below to compare optimizers\n",
    "optimizer = TritonAdamW(model.parameters())\n",
    "torch.cuda.empty_cache()\n",
    "inp_ids = torch.arange(128).reshape(4, -1).cuda()\n",
    "for _ in tqdm(range(iters)):\n",
    "    # Clear gradients at the start of the iteration (not the end) so the gradients of the\n",
    "    # final backward pass are still populated for the benchmark below.\n",
    "    model.zero_grad()\n",
    "    out = model(inp_ids)\n",
    "    out.logits.mean().backward()\n",
    "    optimizer.step()\n",
    "print(p)  # should come out roughly the same for every optimizer in this section\n",
    "# Time optimizer.step() alone, reusing the gradients from the last backward pass.\n",
    "ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)\n",
    "print(ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1阶2阶动量为bf16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|▎         | 3/100 [00:00<00:08, 12.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "finish_custom_init, p_dtype: torch.float32, master_p_dtype: None\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▌         | 5/100 [00:00<00:07, 13.55it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:06<00:00, 15.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[-0.1063,  0.1401,  0.0998,  ..., -0.0788, -0.0872, -0.0942],\n",
      "        [ 0.0785, -0.0951,  0.0865,  ..., -0.1053,  0.1029,  0.0856],\n",
      "        [-0.0803,  0.0811, -0.0938,  ...,  0.0920, -0.1115,  0.0943],\n",
      "        ...,\n",
      "        [ 0.1394, -0.1390,  0.1343,  ..., -0.1392, -0.1395,  0.1501],\n",
      "        [ 0.1394, -0.1390,  0.1343,  ..., -0.1392, -0.1395,  0.1501],\n",
      "        [ 0.1394, -0.1390,  0.1343,  ..., -0.1392, -0.1395,  0.1501]],\n",
      "       device='cuda:0', requires_grad=True)\n",
      "7.0299811363220215\n"
     ]
    }
   ],
   "source": [
    "# Reload the model so this run starts from the pretrained weights instead of the state\n",
    "# left behind by a previously executed optimizer cell (its training and do_bench steps).\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path).cuda()\n",
    "p = next(model.parameters())  # first parameter, printed below to compare optimizers\n",
    "# Keep both Adam moment buffers (exp_avg, exp_avg_sq) in bf16.\n",
    "optimizer = TritonAdamW(\n",
    "    model.parameters(),\n",
    "    exp_avg_dtype=torch.bfloat16,\n",
    "    exp_avg_sq_dtype=torch.bfloat16,\n",
    ")\n",
    "torch.cuda.empty_cache()\n",
    "inp_ids = torch.arange(128).reshape(4, -1).cuda()\n",
    "for _ in tqdm(range(iters)):\n",
    "    # Clear gradients at the start of the iteration (not the end) so the gradients of the\n",
    "    # final backward pass are still populated for the benchmark below.\n",
    "    model.zero_grad()\n",
    "    out = model(inp_ids)\n",
    "    out.logits.mean().backward()\n",
    "    optimizer.step()\n",
    "print(p)  # should come out roughly the same for every optimizer in this section\n",
    "# Time optimizer.step() alone, reusing the gradients from the last backward pass.\n",
    "ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)\n",
    "print(ms)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. 混合精度训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/workspace/mdy/miniforge/envs/mdy/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/mnt/workspace/mdy/miniforge/envs/mdy/lib/python3.10/site-packages/_distutils_hack/__init__.py:53: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n",
      "  warnings.warn(\n",
      "/mnt/workspace/mdy/miniforge/envs/mdy/lib/python3.10/site-packages/_distutils_hack/__init__.py:53: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "import triton\n",
    "import triton.language as tl\n",
    "from TritonAdam import TritonAdamW\n",
    "from apex.optimizers import FusedAdam as ApexFusedAdam\n",
    "from transformer_engine.pytorch.optimizers import FusedAdam as TEFusedAdam\n",
    "import os\n",
    "os.environ['TRITON_PRINT_AUTOTUNING'] = '1'\n",
    "torch.cuda.empty_cache()\n",
    "import torch._inductor.runtime.hints"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载模型\n",
    "- 加载bf16的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint location; set the MODEL_PATH environment variable to override the default.\n",
    "model_path = os.environ.get('MODEL_PATH', '/mnt/workspace/mdy/models/Qwen2.5-0.5B-Instruct')\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "# Load the weights in bf16 for the mixed-precision benchmarks.\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda()\n",
    "iters = 100  # iteration count shared by the benchmark cells\n",
    "p = next(model.parameters())  # handle to the first parameter, kept for inspection\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Apex\n",
    "- Apex只支持标准的混合精度训练：master weight和1、2阶动量为fp32，model weight和grad为bf16/fp16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13.046334266662598\n"
     ]
    }
   ],
   "source": [
    "optimizer = ApexFusedAdam(model.parameters(), capturable=True, master_weights=True)\n",
    "torch.cuda.empty_cache()\n",
    "inp_ids = torch.arange(128).reshape(4, -1).cuda()\n",
    "# One forward/backward pass to populate the gradients, followed by a single warm-up step.\n",
    "# (This used to be a tqdm loop that always hit `break` after its first iteration.)\n",
    "out = model(inp_ids)\n",
    "out.logits.mean().backward()\n",
    "optimizer.step()\n",
    "# Time optimizer.step() alone, reusing the gradients computed above.\n",
    "ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)\n",
    "print(ms)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transformer Engine\n",
    "- 最新的TE支持多种精度（可以点进去看下），比如1、2阶动量支持fp16、int8等，不过都需要进行scale，并且不支持bf16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 标准版"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.840949058532715\n"
     ]
    }
   ],
   "source": [
    "optimizer = TEFusedAdam(model.parameters(), master_weights=True)\n",
    "torch.cuda.empty_cache()\n",
    "inp_ids = torch.arange(128).reshape(4, -1).cuda()\n",
    "# One forward/backward pass to populate the gradients, followed by a single warm-up step.\n",
    "# (This used to be a tqdm loop that always hit `break` after its first iteration.)\n",
    "out = model(inp_ids)\n",
    "out.logits.mean().backward()\n",
    "optimizer.step()\n",
    "# Time optimizer.step() alone, reusing the gradients computed above.\n",
    "ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)\n",
    "print(ms)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### fp16的1，2阶动量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64.45989990234375\n"
     ]
    }
   ],
   "source": [
    "# Keep both Adam moment buffers (exp_avg, exp_avg_sq) in fp16.\n",
    "optimizer = TEFusedAdam(\n",
    "    model.parameters(),\n",
    "    exp_avg_dtype=torch.float16,\n",
    "    exp_avg_sq_dtype=torch.float16,\n",
    "    master_weights=True,\n",
    ")\n",
    "torch.cuda.empty_cache()\n",
    "inp_ids = torch.arange(128).reshape(4, -1).cuda()\n",
    "# One forward/backward pass to populate the gradients, followed by a single warm-up step.\n",
    "# (This used to be a tqdm loop that always hit `break` after its first iteration.)\n",
    "out = model(inp_ids)\n",
    "out.logits.mean().backward()\n",
    "optimizer.step()\n",
    "# Time optimizer.step() alone, reusing the gradients computed above.\n",
    "ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)\n",
    "print(ms)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### fp16的1，2阶动量 + fp32的grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64.99529266357422\n"
     ]
    }
   ],
   "source": [
    "# fp16 moment buffers, with gradients supplied through the decoupled_grad attribute.\n",
    "optimizer = TEFusedAdam(\n",
    "    model.parameters(),\n",
    "    exp_avg_dtype=torch.float16,\n",
    "    exp_avg_sq_dtype=torch.float16,\n",
    "    use_decoupled_grad=True,\n",
    "    master_weights=True,\n",
    ")\n",
    "torch.cuda.empty_cache()\n",
    "inp_ids = torch.arange(128).reshape(4, -1).cuda()\n",
    "# One forward/backward pass to produce the (bf16) gradients.\n",
    "# (This used to be a tqdm loop that always hit `break` after its first iteration.)\n",
    "out = model(inp_ids)\n",
    "out.logits.mean().backward()\n",
    "\n",
    "# .grad must have the same dtype as the parameter, so an fp32 gradient for a bf16\n",
    "# parameter has to be stored in a separate attribute.\n",
    "# A dedicated loop variable keeps the notebook-level `p` from being overwritten.\n",
    "for param in model.parameters():\n",
    "    if param.grad is None:  # no gradient was produced for this parameter\n",
    "        continue\n",
    "    param.decoupled_grad = param.grad.float()\n",
    "    param.grad = None\n",
    "\n",
    "# Warm-up step, taken only after decoupled_grad has been populated because the\n",
    "# optimizer was created with use_decoupled_grad=True.\n",
    "optimizer.step()\n",
    "ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)\n",
    "print(ms)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Triton Adam\n",
    "- 基本就是对着TE中的进行写的，接口基本都差不多，无缝衔接Megatron框架\n",
    "- 目前支持：master weight为fp32，model weight为bf16，grad为fp32或bf16，1、2阶动量为bf16或fp32。无多余功能，基本满足训练需求"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 标准版"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "finish_custom_init, p_dtype: torch.bfloat16, master_p_dtype: torch.float32\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:01<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.474020957946777\n"
     ]
    }
   ],
   "source": [
    "optimizer = TritonAdamW(model.parameters(), master_weights=True)\n",
    "torch.cuda.empty_cache()\n",
    "inp_ids = torch.arange(128).reshape(4, -1).cuda()\n",
    "# One forward/backward pass to populate the gradients, followed by a single warm-up step.\n",
    "# (This used to be a tqdm loop that always hit `break` after its first iteration.)\n",
    "out = model(inp_ids)\n",
    "out.logits.mean().backward()\n",
    "optimizer.step()\n",
    "# Time optimizer.step() alone, reusing the gradients computed above.\n",
    "ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)\n",
    "print(ms)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### bf16的1，2阶动量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "finish_custom_init, p_dtype: torch.bfloat16, master_p_dtype: torch.float32\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.561221122741699\n"
     ]
    }
   ],
   "source": [
    "# Keep both Adam moment buffers (exp_avg, exp_avg_sq) in bf16.\n",
    "optimizer = TritonAdamW(\n",
    "    model.parameters(),\n",
    "    exp_avg_dtype=torch.bfloat16,\n",
    "    exp_avg_sq_dtype=torch.bfloat16,\n",
    "    master_weights=True,\n",
    ")\n",
    "torch.cuda.empty_cache()\n",
    "inp_ids = torch.arange(128).reshape(4, -1).cuda()\n",
    "# One forward/backward pass to populate the gradients, followed by a single warm-up step.\n",
    "# (This used to be a tqdm loop that always hit `break` after its first iteration.)\n",
    "out = model(inp_ids)\n",
    "out.logits.mean().backward()\n",
    "optimizer.step()\n",
    "# Time optimizer.step() alone, reusing the gradients computed above.\n",
    "ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)\n",
    "print(ms)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### bf16的1，2阶动量 + fp32的grad\n",
    "- 这个就是deepseekv3中的配置\n",
    "- 在megatron中，它会使用一个fp32的grad buffer去存储梯度\n",
    "- 所有micro batch的梯度都通过hook累加到这个buffer中，下面是伪代码：\n",
    "- `grad_buffer += p.grad`\n",
    "- `p.grad = None`\n",
    "- 当所有micro batch都计算完后：\n",
    "- `optimizer.model_p.decoupled_grad = grad_buffer`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "finish_custom_init, p_dtype: torch.bfloat16, master_p_dtype: torch.float32\n",
      "7.5503644943237305\n"
     ]
    }
   ],
   "source": [
    "# bf16 moment buffers, with gradients supplied through the decoupled_grad attribute.\n",
    "optimizer = TritonAdamW(\n",
    "    model.parameters(),\n",
    "    exp_avg_dtype=torch.bfloat16,\n",
    "    exp_avg_sq_dtype=torch.bfloat16,\n",
    "    use_decoupled_grad=True,\n",
    "    master_weights=True,\n",
    ")\n",
    "torch.cuda.empty_cache()\n",
    "inp_ids = torch.arange(128).reshape(4, -1).cuda()\n",
    "# One forward/backward pass to produce the (bf16) gradients.\n",
    "# (This used to be a tqdm loop that always hit `break` after its first iteration.)\n",
    "out = model(inp_ids)\n",
    "out.logits.mean().backward()\n",
    "\n",
    "# .grad must have the same dtype as the parameter, so an fp32 gradient for a bf16\n",
    "# parameter has to be stored in a separate attribute.\n",
    "# A dedicated loop variable keeps the notebook-level `p` from being overwritten.\n",
    "for param in model.parameters():\n",
    "    if param.grad is None:  # no gradient was produced for this parameter\n",
    "        continue\n",
    "    param.decoupled_grad = param.grad.float()\n",
    "    param.grad = None\n",
    "\n",
    "# Warm-up step, taken only after decoupled_grad has been populated because the\n",
    "# optimizer was created with use_decoupled_grad=True.\n",
    "optimizer.step()\n",
    "ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)\n",
    "print(ms)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
